
Add FSQ implementation #74

Merged · 14 commits · Sep 29, 2023

Conversation

@sekstini (Contributor) commented Sep 28, 2023

TODO:

  • Verify correctness
  • Add usage example

Notes:

  • Torch doesn't support uint32 yet, so we use int32. Should be fine.
  • Part of the offset calculation (in the bound function) is missing. I just took Copilot's suggestion for now.
  • Fixed some grammatical errors (indexes -> indices, and an incorrect docstring).

@lucidrains (Owner)

nice! give it a test drive with the cifar script in the examples folder

@kashif ^

@sekstini (Contributor Author)

Still not 100% sure this is correct, but initial results seem promising.

vq  :: rec loss: 0.114 | cmt loss: 0.001 | active %: 21.094
fsq :: rec loss: 0.111 | active %: 58.333

Not parameter-matched, so the losses aren't really comparable, but interesting to see the higher codebook usage.

@lucidrains (Owner)

do you have the training curves for each?

@sekstini (Contributor Author) commented Sep 28, 2023

edit: removed the images that were here because the test was broken

@sekstini (Contributor Author) commented Sep 28, 2023

Probably isn't representative of FSQ performance in general, but it does seem to be working on some level at least ^^

@sekstini (Contributor Author)

Ok wait, strike that. Reading the paper a bit more closely, the offsets should be flipped relative to what I put there initially. Doing so nets a much higher active %.

Review thread on examples/autoencoder_fsq.py (outdated, resolved)
@sekstini (Contributor Author) commented Sep 28, 2023

[image: plot]

Okay, this seems fine for an example. Curious to see if it works in more realistic scenarios too.

@lucidrains (Owner)

@sekstini getting there! we can let @kashif do a review before getting it merged! thanks for beasting through this

return z + (zhat - z).detach()


class FSQ(nn.Module):
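
For context, the quoted line is the straight-through estimator: the forward pass uses the rounded value, while gradients flow back as if no rounding happened. A minimal self-contained sketch (the round_ste helper is illustrative, not necessarily the exact code in this PR):

import torch

def round_ste(z: torch.Tensor) -> torch.Tensor:
    # forward: round(z); backward: identity gradient (straight-through)
    zhat = z.round()
    return z + (zhat - z).detach()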
Review comment:

Thanks so much for porting this! Do you mind if we link this repo in the next version and our own public code release?

Review comment:

LMK if you are also planning to update the README and I can send some figs.

Review comment (Contributor):

please do! and i believe @lucidrains and @sekstini will appreciate it

Review comment (Contributor Author):

Yup, go ahead 👍

Review comment (Contributor Author):

@fab-jul

LMK if you are also planning to update the README and I can send some figs.

That would be great 🙏

@fab-jul commented Sep 29, 2023

[image: plot]

Okay, this seems fine for an example. Curious to see if it works in more realistic scenarios too.

Cool to see the high util out of the box :) I assume this is a fairly shallow AE? Probably the loss gap could be a bit smaller but looks like it's WAI! nice

@sekstini sekstini marked this pull request as ready for review September 29, 2023 07:20
@kashif (Contributor) commented Sep 29, 2023

@sekstini can you also kindly add a section in the README and the appropriate bibtex entry as #73 had started?

@sekstini (Contributor Author) commented Sep 29, 2023

[image: plot]
Okay, this seems fine for an example. Curious to see if it works in more realistic scenarios too.

Cool to see the high util out of the box :) I assume this is a fairly shallow AE? Probably the loss gap could be a bit smaller but looks like it's WAI! nice

Yeah, the test models are both shallow, and tiny at ~10k parameters. Don't expect this to realistically reflect performance at all. This is the FSQ network.

@sekstini can you also kindly add a section in the README and the appropriate bibtex entry as #73 had started?

Will do 👍

@fab-jul commented Sep 29, 2023

Yeah, the test models are both shallow, and tiny at ~10k parameters. Don't expect this to realistically reflect performance at all. This is the FSQ network.

Makes sense! Yeah, one thing with FSQ is that because your codebook is fixed, you need some decent number of layers before and after the quantizer. Usually this is the case (e.g. the VQ-GANs people train are fairly deep).

Anyway, I think the results you see are what we expect! Good stuff.

For the figures, I revisited your README and it seems it's mostly code, so maybe simply using the cube (PDF here) would be enough, or already too much. Your call :)

In the README of our upcoming mini-repo, I also added the table, but again, this might be out of place in your README:

|                  | VQ | FSQ |
|------------------|----|-----|
| Quantization     | argmin_c \|\| z-c \|\| | round(f(z)) |
| Gradients        | Straight Through Estimation (STE) | STE |
| Auxiliary Losses | Commitment, codebook, entropy loss, ... | N/A |
| Tricks           | EMA on codebook, codebook splitting, projections, ... | N/A |
| Parameters       | Codebook | N/A |
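
In code, the contrast is roughly the following (a toy sketch only, not our implementation; the FSQ half ignores the even-level offset/shift details):

import torch

def vq_quantize(z, codebook):
    # VQ: snap to the nearest learned codebook vector, argmin_c || z - c ||
    idx = torch.cdist(z, codebook).argmin(dim=-1)
    zq = codebook[idx]
    return z + (zq - z).detach(), idx          # gradients via STE

def fsq_quantize(z, levels):
    # FSQ: bound each channel with f(z) = tanh(z) * (levels - 1) / 2, then round;
    # the "codebook" is just the implied integer grid, with no learned parameters
    half = (levels - 1) / 2
    zb = torch.tanh(z) * half
    zq = torch.round(zb)
    return zb + (zq - zb).detach()             # gradients via STE

# e.g. fsq_quantize(torch.randn(16, 4), levels=torch.tensor([8., 5., 5., 5.]))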

@lucidrains (Owner)

lgtm! releasing

thank you @sekstini !

@lucidrains lucidrains merged commit 0cce037 into lucidrains:fsq Sep 29, 2023
@sekstini (Contributor Author)

@lucidrains Oh, apparently I pointed this at the fsq branch, so you might need to merge it into master

@lucidrains (Owner)

@sekstini yup, no problem, thank you! 🙏

@fab-jul commented Sep 29, 2023

Thanks everyone! I'll add a link to this repo in our next revision

Added a link to the official README now.

@dribnet commented Oct 1, 2023

Appreciate this quick port and example code. Just thought I would add that reconstructions from the FashionMNIST example appear reasonable:

[image: figgy1]

However, I was a bit surprised to notice that it fails when switching back to MNIST:

[image: figgy2]

Changing levels/seeds didn't help, so perhaps, as @fab-jul mentioned, it's just a case of needing more layers before/after the quantizer.

@sekstini (Contributor Author) commented Oct 1, 2023

@dribnet Interesting. Not sure why we would see this particular failure mode here, but I made a toy example of residual vector quantization where it's working: https://gist.github.com/sekstini/7f089f71d4b975ec8bde37d878b514d0.

[image: residual_fsq]

@lucidrains (Owner)

@dribnet adding more layers seems testable

@lucidrains (Owner)

LFQ https://arxiv.org/abs/2310.05737 looks similar?

@lucidrains (Owner)

nevermind, it is slightly different, will be adding

@fab-jul commented Oct 10, 2023

nevermind, it is slightly different, will be adding

It's FSQ with levels = 2 plus an entropy maximization loss, IIUC.
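
Roughly, the quantizer part would look like this (a sketch of the idea only, not the MAGVIT-2 code; scaling details and the entropy loss are omitted):

import torch

def lfq_quantize(z: torch.Tensor):
    # quantize each channel to its sign, i.e. a 2-level grid per dimension
    zhat = torch.where(z > 0, 1.0, -1.0)
    quantized = z + (zhat - z).detach()        # straight-through estimator
    # pack the sign bits into an integer index (implicit codebook of size 2^d)
    bits = (zhat > 0).long()
    powers = 2 ** torch.arange(z.shape[-1], device=z.device)
    indices = (bits * powers).sum(dim=-1)
    return quantized, indices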

@lucidrains (Owner)

yes, FSQ generalizes it, save for the entropy loss. no ablation of the entropy loss, so unsure how necessary it is, but i'll add it. perhaps that will be their contribution

@lucidrains (Owner)

@fab-jul congrats either way! this could be a winner

@fab-jul commented Oct 10, 2023

@lucidrains thanks!

for us, levels = 2 was suboptimal but I can see how an entropy maximization would fix that. I wonder if that also improves FSQ (although ofc the goal of our paper was to get rid of all aux losses haha)

@sekstini (Contributor Author)

@sekstini you should pair up with an audio researcher and think hard about this (residual fsq), maybe do a tweetstorm or blogpost if you see anything but do not have bandwidth to write a paper. i'm sure many will be thinking along these lines, looking at the new paper

Funny you mention it, as I'm experimenting with a vocoder based on this idea at this very moment ^^
I should definitely improve my tweeting game though.

@lucidrains (Owner)

@sekstini are you seeing good results?

thinking about a soundstream variation with multi-headed LFQ

@sekstini (Contributor Author)

@lucidrains

@sekstini are you seeing good results?

Not really, but definitely not because of FSQ.

I have exclusively been toying with various parallel encoding schemes, which I'm guessing are difficult for the model to learn, and I suspect residual quantization would work a lot better.

@lucidrains (Owner) commented Oct 16, 2023

@sekstini ah got it, thanks!

yea, to do residual, i think the codes will need to be scaled down an order of magnitude, or inputs scaled up (cube within cubes, if you can imagine it), but i haven't worked it out. it is probably a 3 month research project. somebody will def end up trying it..
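
something like this is the shape of what i mean (just a sketch with made-up scales, not worked out):

import torch

def round_ste(z):
    # straight-through rounding
    return z + (z.round() - z).detach()

def residual_fsq(z, levels: int = 5, stages: int = 3, shrink: float = 0.1):
    # cube within cubes: each stage quantizes the leftover residual on a grid
    # that is an order of magnitude smaller than the previous one
    half = (levels - 1) / 2
    out, residual = torch.zeros_like(z), z
    for stage in range(stages):
        scale = shrink ** stage                          # 1, 0.1, 0.01, ...
        code = round_ste(torch.tanh(residual / scale) * half) / half * scale
        out = out + code
        residual = residual - code.detach()
    return out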

@lucidrains (Owner)

@sekstini yea, if LFQ pans out over at magvit2, i'll do some improvisation here and maybe someone can do the hard experimental work.

@lucidrains (Owner)

@sekstini almost done! #80

@sekstini (Contributor Author)

@sekstini almost done! #80

Neat! I made some decent progress on my vocoder with "parallel FSQ", but I'd be interested in swapping this in to compare performance. Feel free to tag me when it's done.

@lucidrains (Owner)

ok it is done, integrated in soundstream over here

@lucidrains (Owner)

may be of interest! lucidrains/magvit2-pytorch#4

@mueller-franzes

Hi,
Thanks for the implementation!
@sekstini or @fab-jul Could you perhaps briefly explain what the purpose of the "shift/offset" is?
The reason I'm asking: for levels=2, "shift" becomes infinite. Should it be (1 + eps) instead of (1 - eps)?

@sekstini (Contributor Author)

@mueller-franzes You may find this comment by Fabian interesting.

Should it be (1 + eps) instead of (1-eps)?

Makes sense to me. I copied it directly from the paper, but at the time the paper had tan instead of atanh there, so I didn't notice any issues while testing levels = 2.

As a side note, you may want to check out LFQ if you're interested in the levels = 2 case in particular.
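
For reference, this is roughly the shape of the bound function from the paper's appendix (names approximate, not necessarily line-for-line what's in this repo):

import torch

def bound(z: torch.Tensor, levels: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    # squash z into roughly [-(levels - 1) / 2, (levels - 1) / 2] per channel,
    # with a half-step offset when the number of levels is even
    half_l = (levels - 1) * (1 - eps) / 2
    offset = torch.where(levels % 2 == 0, 0.5, 0.0)
    shift = torch.atanh(offset / half_l)
    return torch.tanh(z + shift) * half_l - offset

# for levels == 2: half_l ≈ 0.4995 and offset = 0.5, so offset / half_l > 1 and
# atanh returns nan; with (1 + eps) instead, half_l ≈ 0.5005 and atanh stays finite
print(bound(torch.zeros(3), torch.tensor([2.0, 5.0, 8.0])))  # first entry is nan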

@lucidrains (Owner)

@mueller-franzes @sekstini maybe we should just enforce odd levels for now for FSQ until a correction to the paper comes out?

@lucidrains (Owner)

and yea, agreed with checking out LFQ. so many groups seeing success with it

@sekstini (Contributor Author)

@mueller-franzes @sekstini maybe we should just enforce odd levels for now for FSQ until a correction to the paper comes out?

Other than the asymmetry being a bit weird, even levels > 2 seem fine (most of my code has actually been using even values).

I think switching to (1 + eps) or enforcing levels > 2 makes sense.

@mueller-franzes

Thank you both for the super quick response! That's a good tip, I'll have a look at LFQ next.

@lucidrains (Owner)

@sekstini oh you are right, it is a levels == 2 problem

ok let's just go with 1 + eps! thank you both
